import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, ConfusionMatrixDisplay


# Load the spreadsheet
data = pd.read_csv("Data/gpt_classifications.csv")

# Keep only the rows in which the gold label and both model predictions are
# present. The filtering is done on the DataFrame itself so that the three
# columns stay aligned row by row: dropping NaNs from each column separately
# and then indexing the shortened lists with DataFrame index labels pairs up
# values from different rows (or raises IndexError) as soon as any value is
# missing.
label_columns = ['majortopic', 'gpt35_num_label', 'gpt4_num_label']
complete_rows = data.dropna(subset=label_columns)

# Index labels of the rows that take part in the evaluation.
common_indices = set(complete_rows.index)

# Extract true labels and predicted labels as equally long lists of ints;
# position k in each list refers to the same spreadsheet row.
true_labels = complete_rows['majortopic'].astype(int).tolist()
gpt35_predictions = complete_rows['gpt35_num_label'].astype(int).tolist()
gpt4_predictions = complete_rows['gpt4_num_label'].astype(int).tolist()

def _report_model(true_labels, predictions, model_name, scenario_name):
    """Print the metrics and show the normalized confusion matrix for one model."""
    print(f"{model_name}:")
    print(classification_report(true_labels, predictions, zero_division=1))
    print(f"Accuracy: {accuracy_score(true_labels, predictions):.2f}\n")

    # Build the label list from THIS model's predictions and sort it, then
    # hand the same list to confusion_matrix and to the heatmap ticks so that
    # row/column k of the matrix is always annotated with labels[k].
    labels = sorted(set(true_labels) | set(predictions))

    # normalize='true' divides every row by its number of true samples; rows
    # for labels that only occur as predictions are returned as zeros instead
    # of producing a division-by-zero warning and NaNs.
    cm_normalized = confusion_matrix(
        true_labels, predictions, labels=labels, normalize='true'
    )

    plt.figure(figsize=(10, 10))
    sns.heatmap(
        cm_normalized,
        cmap="Blues",
        annot=False,
        cbar=True,
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'{model_name} Normalized Confusion Matrix for {scenario_name}')
    plt.show()


def calculate_and_display_statistics(true_labels, gpt35_predictions, gpt4_predictions, scenario_name):
    """Print classification metrics and plot confusion matrices for both models.

    For GPT 3.5 Turbo and then GPT 4, prints the scikit-learn classification
    report and the accuracy, and shows a heatmap of the confusion matrix
    normalized over the true labels (each row sums to 1 when that label has
    at least one true sample).

    Args:
        true_labels: Sequence of reference labels.
        gpt35_predictions: GPT 3.5 Turbo predictions, same length and order
            as ``true_labels``.
        gpt4_predictions: GPT 4 predictions, same length and order as
            ``true_labels``.
        scenario_name: Text used in the printed header and the plot titles.

    Returns:
        None. Output goes to stdout and to matplotlib figures.
    """
    print(f"Classification Statistics for {scenario_name}:\n")

    _report_model(true_labels, gpt35_predictions, "GPT 3.5 Turbo", scenario_name)
    _report_model(true_labels, gpt4_predictions, "GPT 4", scenario_name)

    print("="*80)



# Scenario 1: evaluate on every aligned row.
calculate_and_display_statistics(true_labels, gpt35_predictions, gpt4_predictions, 'Scenario 1')

# Scenario 2: drop the rows where GPT 3.5 predicted label 199.
# NOTE: despite its name, `excluded_indices` holds the positions that are
# KEPT, and the `_scenario_1` suffix below belongs to the run reported as
# 'Scenario 2'.
excluded_indices = [
    position
    for position, label in enumerate(gpt35_predictions)
    if label != 199
]
true_labels_scenario_1, gpt35_predictions_scenario_1, gpt4_predictions_scenario_1 = (
    [column[position] for position in excluded_indices]
    for column in (true_labels, gpt35_predictions, gpt4_predictions)
)

calculate_and_display_statistics(
    true_labels_scenario_1,
    gpt35_predictions_scenario_1,
    gpt4_predictions_scenario_1,
    'Scenario 2',
)

# Scenario 3: keep only the rows on which both models agree.
# NOTE: the `_scenario_2` suffix below belongs to the run reported as
# 'Scenario 3'.
common_predictions_indices = [
    position
    for position, pair in enumerate(zip(gpt35_predictions, gpt4_predictions))
    if pair[0] == pair[1]
]
true_labels_scenario_2, gpt35_predictions_scenario_2, gpt4_predictions_scenario_2 = (
    [column[position] for position in common_predictions_indices]
    for column in (true_labels, gpt35_predictions, gpt4_predictions)
)

calculate_and_display_statistics(
    true_labels_scenario_2,
    gpt35_predictions_scenario_2,
    gpt4_predictions_scenario_2,
    'Scenario 3',
)